import pandas as pd
Install required packages to access images from S3 storage
!pip install -U imagecodecs s3fs tifffile
!git clone https://github.com/jump-cellpainting/JUMP-Target
!git clone https://github.com/jump-cellpainting/datasets.git
Check the different data sources available in the CP-JUMP database
jump_plates_metadata = pd.read_csv("datasets/metadata/plate.csv.gz")
jump_plates_metadata| Metadata_Source | Metadata_Batch | Metadata_Plate | Metadata_PlateType | |
|---|---|---|---|---|
| 0 | source_1 | Batch1_20221004 | UL000109 | COMPOUND_EMPTY |
| 1 | source_1 | Batch1_20221004 | UL001641 | COMPOUND |
| 2 | source_1 | Batch1_20221004 | UL001643 | COMPOUND |
| 3 | source_1 | Batch1_20221004 | UL001645 | COMPOUND |
| 4 | source_1 | Batch1_20221004 | UL001651 | COMPOUND |
| ... | ... | ... | ... | ... |
| 2520 | source_9 | 20211103-Run16 | GR00004417 | COMPOUND |
| 2521 | source_9 | 20211103-Run16 | GR00004418 | COMPOUND |
| 2522 | source_9 | 20211103-Run16 | GR00004419 | COMPOUND |
| 2523 | source_9 | 20211103-Run16 | GR00004420 | COMPOUND |
| 2524 | source_9 | 20211103-Run16 | GR00004421 | COMPOUND |
2525 rows × 4 columns
jump_plates_metadata["Metadata_PlateType"].unique()array(['COMPOUND_EMPTY', 'COMPOUND', 'DMSO', 'TARGET2', 'CRISPR', 'ORF',
'TARGET1', 'POSCON8'], dtype=object)
jump_plates_metadata.groupby(["Metadata_Source", "Metadata_Batch"]).describe()| Metadata_Plate | Metadata_PlateType | ||||||||
|---|---|---|---|---|---|---|---|---|---|
| count | unique | top | freq | count | unique | top | freq | ||
| Metadata_Source | Metadata_Batch | ||||||||
| source_1 | Batch1_20221004 | 9 | 9 | UL000109 | 1 | 9 | 2 | COMPOUND | 6 |
| Batch2_20221006 | 7 | 7 | UL001647 | 1 | 7 | 1 | COMPOUND | 7 | |
| Batch3_20221010 | 8 | 8 | UL000087 | 1 | 8 | 1 | COMPOUND | 8 | |
| Batch4_20221012 | 8 | 8 | UL000081 | 1 | 8 | 1 | COMPOUND | 8 | |
| Batch5_20221030 | 11 | 11 | UL000561 | 1 | 11 | 2 | COMPOUND | 10 | |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| source_9 | 20210918-Run11 | 9 | 9 | GR00004367 | 1 | 9 | 2 | COMPOUND | 8 |
| 20210918-Run12 | 8 | 8 | GR00004377 | 1 | 8 | 1 | COMPOUND | 8 | |
| 20211013-Run14 | 13 | 13 | GR00003279 | 1 | 13 | 2 | COMPOUND | 12 | |
| 20211102-Run15 | 11 | 11 | GR00004391 | 1 | 11 | 2 | COMPOUND | 10 | |
| 20211103-Run16 | 17 | 17 | GR00004405 | 1 | 17 | 2 | COMPOUND | 16 | |
149 rows × 8 columns
crispr_wells_metadata = pd.read_csv("JUMP-Target/JUMP-Target-1_crispr_platemap.tsv", sep="\t")
crispr_wells_metadata["Plate_type"] = "CRISPR"
crispr_wells_metadata["Plate_label"] = 1
crispr_wells_metadata| well_position | broad_sample | Plate_type | Plate_label | |
|---|---|---|---|---|
| 0 | A01 | BRDN0001480888 | CRISPR | 1 |
| 1 | A02 | BRDN0001483495 | CRISPR | 1 |
| 2 | A03 | BRDN0001147364 | CRISPR | 1 |
| 3 | A04 | BRDN0001490272 | CRISPR | 1 |
| 4 | A05 | BRDN0001480510 | CRISPR | 1 |
| ... | ... | ... | ... | ... |
| 379 | P20 | BRDN0001145303 | CRISPR | 1 |
| 380 | P21 | BRDN0001484228 | CRISPR | 1 |
| 381 | P22 | BRDN0001487618 | CRISPR | 1 |
| 382 | P23 | BRDN0001487864 | CRISPR | 1 |
| 383 | P24 | BRDN0000735603 | CRISPR | 1 |
384 rows × 4 columns
orf_wells_metadata = pd.read_csv("JUMP-Target/JUMP-Target-1_orf_platemap.tsv", sep="\t")
orf_wells_metadata["Plate_type"] = "ORF"
orf_wells_metadata["Plate_label"] = 2
orf_wells_metadata| well_position | broad_sample | Plate_type | Plate_label | |
|---|---|---|---|---|
| 0 | A01 | ccsbBroad304_00900 | ORF | 2 |
| 1 | A02 | ccsbBroad304_07795 | ORF | 2 |
| 2 | A03 | ccsbBroad304_02826 | ORF | 2 |
| 3 | A04 | ccsbBroad304_01492 | ORF | 2 |
| 4 | A05 | ccsbBroad304_00691 | ORF | 2 |
| ... | ... | ... | ... | ... |
| 379 | P20 | ccsbBroad304_00277 | ORF | 2 |
| 380 | P21 | ccsbBroad304_06464 | ORF | 2 |
| 381 | P22 | ccsbBroad304_00476 | ORF | 2 |
| 382 | P23 | ccsbBroad304_01649 | ORF | 2 |
| 383 | P24 | ccsbBroad304_03934 | ORF | 2 |
384 rows × 4 columns
compound_wells_metadata = pd.read_csv("JUMP-Target/JUMP-Target-1_compound_platemap.tsv", sep="\t")
compound_wells_metadata["Plate_type"] = "COMPOUND"
compound_wells_metadata["Plate_label"] = 3
compound_wells_metadata| well_position | broad_sample | solvent | Plate_type | Plate_label | |
|---|---|---|---|---|---|
| 0 | A01 | BRD-A86665761-001-01-1 | DMSO | COMPOUND | 3 |
| 1 | A02 | NaN | DMSO | COMPOUND | 3 |
| 2 | A03 | BRD-A22032524-074-09-9 | DMSO | COMPOUND | 3 |
| 3 | A04 | BRD-A01078468-001-14-8 | DMSO | COMPOUND | 3 |
| 4 | A05 | BRD-K48278478-001-01-2 | DMSO | COMPOUND | 3 |
| ... | ... | ... | ... | ... | ... |
| 379 | P20 | BRD-K68982262-001-01-4 | DMSO | COMPOUND | 3 |
| 380 | P21 | BRD-K24616672-003-20-1 | DMSO | COMPOUND | 3 |
| 381 | P22 | BRD-A82396632-008-30-8 | DMSO | COMPOUND | 3 |
| 382 | P23 | BRD-K61250553-003-30-6 | DMSO | COMPOUND | 3 |
| 383 | P24 | BRD-K70358946-001-17-3 | DMSO | COMPOUND | 3 |
384 rows × 5 columns
wells_metadata = pd.concat([compound_wells_metadata, orf_wells_metadata, crispr_wells_metadata])
wells_metadata| well_position | broad_sample | solvent | Plate_type | Plate_label | |
|---|---|---|---|---|---|
| 0 | A01 | BRD-A86665761-001-01-1 | DMSO | COMPOUND | 3 |
| 1 | A02 | NaN | DMSO | COMPOUND | 3 |
| 2 | A03 | BRD-A22032524-074-09-9 | DMSO | COMPOUND | 3 |
| 3 | A04 | BRD-A01078468-001-14-8 | DMSO | COMPOUND | 3 |
| 4 | A05 | BRD-K48278478-001-01-2 | DMSO | COMPOUND | 3 |
| ... | ... | ... | ... | ... | ... |
| 379 | P20 | BRDN0001145303 | NaN | CRISPR | 1 |
| 380 | P21 | BRDN0001484228 | NaN | CRISPR | 1 |
| 381 | P22 | BRDN0001487618 | NaN | CRISPR | 1 |
| 382 | P23 | BRDN0001487864 | NaN | CRISPR | 1 |
| 383 | P24 | BRDN0000735603 | NaN | CRISPR | 1 |
1152 rows × 5 columns
wells_metadata.loc[wells_metadata["broad_sample"].isna(), "Plate_label"] = 0
wells_metadata
| well_position | broad_sample | solvent | Plate_type | Plate_label | |
|---|---|---|---|---|---|
| 0 | A01 | BRD-A86665761-001-01-1 | DMSO | COMPOUND | 3 |
| 1 | A02 | NaN | DMSO | COMPOUND | 0 |
| 2 | A03 | BRD-A22032524-074-09-9 | DMSO | COMPOUND | 3 |
| 3 | A04 | BRD-A01078468-001-14-8 | DMSO | COMPOUND | 3 |
| 4 | A05 | BRD-K48278478-001-01-2 | DMSO | COMPOUND | 3 |
| ... | ... | ... | ... | ... | ... |
| 379 | P20 | BRDN0001145303 | NaN | CRISPR | 1 |
| 380 | P21 | BRDN0001484228 | NaN | CRISPR | 1 |
| 381 | P22 | BRDN0001487618 | NaN | CRISPR | 1 |
| 382 | P23 | BRDN0001487864 | NaN | CRISPR | 1 |
| 383 | P24 | BRDN0000735603 | NaN | CRISPR | 1 |
1152 rows × 5 columns
Review information related to each perturbation in the Broad Institute Genetic Perturbation Platform (https://portals.broadinstitute.org/gpp/public/)
Review the CP-JUMP data directly from the AWS bucket
The Cell Painting Image Collection Registry of Open Data on AWS (https://registry.opendata.aws/cellpainting-gallery/) is a collection of microscopy image sets.
The AWS bucket can be found here: https://cellpainting-gallery.s3.amazonaws.com/index.html
Get the URL of each assay plate from the bucket
import s3fs
fs = s3fs.S3FileSystem(anon=True)
batch_names = {}
plate_paths = {}
source_names = {}
plate_types = {}
for _, src_row in jump_plates_metadata.groupby(["Metadata_Source", "Metadata_Batch"]).describe().iterrows():
source_name, batch_name = src_row.name
# Ignore 'source_8' since the naming of the images is not standard
if source_name in ["source_8"]:
continue
plate_type = src_row["Metadata_PlateType"].top
for plate_path in fs.ls(f"cellpainting-gallery/cpg0016-jump/{source_name}/images/{batch_name}/images/"):
plate_path = plate_path.split("/")[-1]
if not plate_path:
continue
plate_name = plate_path.split("__")[0]
source_names[plate_name] = source_name
batch_names[plate_name] = batch_name
plate_types[plate_name] = plate_type
        plate_paths[plate_name] = plate_path
plate_maps = pd.DataFrame()
plate_maps["Plate_name"] = batch_names.keys()
plate_maps["Source_name"] = plate_maps["Plate_name"].map(source_names)
plate_maps["Batch_name"] = plate_maps["Plate_name"].map(batch_names)
plate_maps["Plate_type"] = plate_maps["Plate_name"].map(plate_types)
plate_maps["Plate_path"] = plate_maps["Plate_name"].map(plate_paths)
plate_maps
| Plate_name | Source_name | Batch_name | Plate_type | Plate_path | |
|---|---|---|---|---|---|
| 0 | UL000109 | source_1 | Batch1_20221004 | COMPOUND | UL000109__2022-10-05T06_35_06-Measurement1 |
| 1 | UL001641 | source_1 | Batch1_20221004 | COMPOUND | UL001641__2022-10-04T23_16_28-Measurement1 |
| 2 | UL001643 | source_1 | Batch1_20221004 | COMPOUND | UL001643__2022-10-04T18_52_42-Measurement2 |
| 3 | UL001645 | source_1 | Batch1_20221004 | COMPOUND | UL001645__2022-10-05T00_44_11-Measurement1 |
| 4 | UL001651 | source_1 | Batch1_20221004 | COMPOUND | UL001651__2022-10-04T20_20_52-Measurement1 |
| ... | ... | ... | ... | ... | ... |
| 2333 | GR00004417 | source_9 | 20211103-Run16 | COMPOUND | GR00004417 |
| 2334 | GR00004418 | source_9 | 20211103-Run16 | COMPOUND | GR00004418 |
| 2335 | GR00004419 | source_9 | 20211103-Run16 | COMPOUND | GR00004419 |
| 2336 | GR00004420 | source_9 | 20211103-Run16 | COMPOUND | GR00004420 |
| 2337 | GR00004421 | source_9 | 20211103-Run16 | COMPOUND | GR00004421 |
2338 rows × 5 columns
comp_plate_maps = plate_maps.query("Plate_type=='COMPOUND'")
comp_plate_maps| Plate_name | Source_name | Batch_name | Plate_type | Plate_path | |
|---|---|---|---|---|---|
| 0 | UL000109 | source_1 | Batch1_20221004 | COMPOUND | UL000109__2022-10-05T06_35_06-Measurement1 |
| 1 | UL001641 | source_1 | Batch1_20221004 | COMPOUND | UL001641__2022-10-04T23_16_28-Measurement1 |
| 2 | UL001643 | source_1 | Batch1_20221004 | COMPOUND | UL001643__2022-10-04T18_52_42-Measurement2 |
| 3 | UL001645 | source_1 | Batch1_20221004 | COMPOUND | UL001645__2022-10-05T00_44_11-Measurement1 |
| 4 | UL001651 | source_1 | Batch1_20221004 | COMPOUND | UL001651__2022-10-04T20_20_52-Measurement1 |
| ... | ... | ... | ... | ... | ... |
| 2333 | GR00004417 | source_9 | 20211103-Run16 | COMPOUND | GR00004417 |
| 2334 | GR00004418 | source_9 | 20211103-Run16 | COMPOUND | GR00004418 |
| 2335 | GR00004419 | source_9 | 20211103-Run16 | COMPOUND | GR00004419 |
| 2336 | GR00004420 | source_9 | 20211103-Run16 | COMPOUND | GR00004420 |
| 2337 | GR00004421 | source_9 | 20211103-Run16 | COMPOUND | GR00004421 |
1905 rows × 5 columns
pert_plate_maps = plate_maps[plate_maps["Plate_type"].isin(["CRISPR", "ORF", "DMSO"])]
pert_plate_maps| Plate_name | Source_name | Batch_name | Plate_type | Plate_path | |
|---|---|---|---|---|---|
| 142 | Dest210628-161651 | source_10 | 2021_06_28_U2OS_48_hr_run9 | DMSO | Dest210628-161651 |
| 143 | Dest210628-162003 | source_10 | 2021_06_28_U2OS_48_hr_run9 | DMSO | Dest210628-162003 |
| 457 | CP-CC9-R1-01 | source_13 | 20220914_Run1 | CRISPR | CP-CC9-R1-01 |
| 458 | CP-CC9-R1-02 | source_13 | 20220914_Run1 | CRISPR | CP-CC9-R1-02 |
| 459 | CP-CC9-R1-03 | source_13 | 20220914_Run1 | CRISPR | CP-CC9-R1-03 |
| ... | ... | ... | ... | ... | ... |
| 1591 | BR00127145 | source_4 | 2021_08_30_Batch13 | ORF | BR00127145__2021-09-22T04_01_46-Measurement1 |
| 1592 | BR00127146 | source_4 | 2021_08_30_Batch13 | ORF | BR00127146__2021-09-22T12_25_07-Measurement1 |
| 1593 | BR00127147 | source_4 | 2021_08_30_Batch13 | ORF | BR00127147__2021-09-18T10_27_12-Measurement1 |
| 1594 | BR00127148 | source_4 | 2021_08_30_Batch13 | ORF | BR00127148__2021-09-21T11_44_23-Measurement1 |
| 1595 | BR00127149 | source_4 | 2021_08_30_Batch13 | ORF | BR00127149__2021-09-18T02_10_04-Measurement1 |
433 rows × 5 columns
pert_plate_maps["Source_name"].unique()array(['source_10', 'source_13', 'source_4'], dtype=object)
comp_plate_maps["Source_name"].unique()array(['source_1', 'source_10', 'source_11', 'source_15', 'source_2',
'source_3', 'source_5', 'source_6', 'source_7', 'source_9'],
dtype=object)
Split the dataset into Training, Validation, and Test sets
import random
import math
trn_plates = []
val_plates = []
tst_plates = []
trn_proportion = 0.7
val_proportion = 0.2
tst_proportion = 0.1
for batch_name in pert_plate_maps["Batch_name"].unique():
plate_names = pert_plate_maps.query(f"Batch_name == '{batch_name}'")["Plate_name"].tolist()
random.shuffle(plate_names)
tst_plates_count = int(math.ceil(len(plate_names) * tst_proportion))
val_plates_count = int(math.ceil(len(plate_names) * val_proportion))
tst_plates += plate_names[:tst_plates_count]
val_plates += plate_names[tst_plates_count:tst_plates_count + val_plates_count]
    trn_plates += plate_names[tst_plates_count + val_plates_count:]
trn_plates[:5]
['CP-CC9-R1-16',
'CP-CC9-R1-22',
'CP-CC9-R1-26',
'CP-CC9-R1-12',
'CP-CC9-R1-29']
val_plates[:5]['Dest210628-162003',
'CP-CC9-R1-13',
'CP-CC9-R1-28',
'CP-CC9-R1-17',
'CP-CC9-R1-24']
tst_plates[:5]['Dest210628-161651',
'CP-CC9-R1-08',
'CP-CC9-R1-01',
'CP-CC9-R1-07',
'CP-CC9-R2-22']
print("Training set size:", len(trn_plates))
print("Validation set size:", len(val_plates))
print("Testing set size:", len(tst_plates))
Training set size: 283
Validation set size: 96
Testing set size: 54
Create a Dataset that can be used with PyTorch
# @title Definition of a Dataset class capable to pull images from AWS S3 buckets
import random
import numpy as np
import string
import s3fs
from itertools import product
from PIL import Image
import tifffile
from torch.utils.data import IterableDataset, get_worker_info
from time import perf_counter
def s3dataset_worker_init_fn(worker_id):
    """TiffS3Dataset DataLoader worker initialization function.

    Assigns each worker a strided slice of the dataset's plate names so the
    workers iterate over disjoint plates, and seeds the ``random`` module
    per worker so shuffle-mode sampling differs across workers.

    Parameters
    ----------
    worker_id : int
        Index of the current DataLoader worker.
    """
    # Use the imported `get_worker_info` (see module imports) instead of the
    # fully qualified `torch.utils.data.get_worker_info()`.
    worker_info = get_worker_info()
    # Every worker takes one plate out of each group of `num_workers` plates.
    w_sel = slice(worker_id, None, worker_info.num_workers)
    dataset_obj = worker_info.dataset
    # Reset the random number generator in each worker. Torch gives every
    # worker a distinct initial seed; forwarding it to `random` ensures the
    # shuffle-mode sampling in TiffS3Dataset differs between workers
    # (previously `torch_seed` was computed but never used, so all workers
    # drew the same pseudo-random sequence).
    torch_seed = torch.initial_seed()
    random.seed(torch_seed % (2 ** 32) + worker_id)
    dataset_obj._worker_sel = w_sel
    dataset_obj._worker_id = worker_id
    dataset_obj._num_workers = worker_info.num_workers
def load_well(plate_metadata, well_row, well_col, field_id, channels, s3):
    """Load all channel images of one field of a well from the S3 bucket.

    Parameters
    ----------
    plate_metadata : dict
        One record of the plate-map table, with keys ``Source_name``,
        ``Batch_name``, ``Plate_name``, and ``Plate_path``.
    well_row : int
        Zero-based well row index (0 -> 'A').
    well_col : int
        Zero-based well column index (0 -> '01').
    field_id : int
        Zero-based field (imaging site) index within the well.
    channels : int
        Number of fluorescence channels to fetch.
    s3 : s3fs.S3FileSystem
        Open (anonymous) S3 filesystem used to read the TIFF files.

    Returns
    -------
    numpy.ndarray or None
        ``(channels, H, W)`` float32 array with intensities scaled to
        [0, 1], or ``None`` when any channel image is missing.

    Raises
    ------
    ValueError
        If the plate's source uses an unsupported file-naming scheme.
    """
    curr_well_image = []
    # Hoist the plate prefix out of the channel loop; it is loop-invariant.
    plate_path = ("cellpainting-gallery/cpg0016-jump/"
                  + plate_metadata["Source_name"] + "/images/"
                  + plate_metadata["Batch_name"] + "/images/"
                  + plate_metadata["Plate_path"])
    for channel_id in range(channels):
        # Each data source stores its images under a different naming scheme.
        if plate_metadata["Source_name"] in ["source_1", "source_3", "source_4", "source_9", "source_11", "source_15"]:
            image_suffix = f"Images/r{well_row + 1:02d}c{well_col + 1:02d}f{field_id + 1:02d}p01-ch{channel_id + 1}sk1fk1fl1.tiff"
        else:
            if plate_metadata["Source_name"] in ["source_2", "source_5"]:
                a_locs = [1, 2, 3, 4, 5]
            elif plate_metadata["Source_name"] in ["source_6", "source_10"]:
                a_locs = [1, 2, 2, 3, 1, 4]
            elif plate_metadata["Source_name"] in ["source_7", "source_13"]:
                a_locs = [1, 1, 2, 3, 4]
            else:
                # Previously `a_locs` was left unbound for unknown sources,
                # raising a confusing UnboundLocalError below.
                raise ValueError(f"Unsupported source naming scheme: {plate_metadata['Source_name']}")
            # NOTE: single quotes inside the f-string — nested double quotes
            # are a SyntaxError before Python 3.12.
            image_suffix = (f"{plate_metadata['Plate_name']}_"
                            f"{string.ascii_uppercase[well_row]}{well_col + 1:02d}"
                            f"_T0001F{field_id + 1:03d}L01A{a_locs[channel_id]:02d}"
                            f"Z01C{channel_id + 1:02d}.tif")
        image_url = "s3://" + plate_path + "/" + image_suffix
        try:
            with s3.open(image_url, 'rb') as f:
                curr_image = tifffile.imread(f)
        except FileNotFoundError:
            # Best-effort: report the missing channel and skip this field.
            print("Failed retrieving:", image_url)
            return None
        # Scale the 16-bit intensities to the [0, 1] range.
        curr_image = curr_image.astype(np.float32)
        curr_image /= 2 ** 16 - 1
        curr_well_image.append(curr_image)
    curr_well_image = np.array(curr_well_image)
    return curr_well_image
class TiffS3Dataset(IterableDataset):
    """Iterable dataset that streams Cell Painting TIFF images from S3.

    Each sample is one field of one well: a ``(channels, 1080, 1080)``
    float32 image, its integer perturbation label, and the plate metadata
    record. With ``shuffle=True`` the dataset yields uniformly random
    plate/well/field combinations, so it can produce a virtually infinite
    stream of samples.
    """

    def __init__(self, plate_maps, wells_metadata, plate_names, well_rows=16, well_cols=24, fields=4, channels=5, shuffle=False):
        """
        Parameters
        ----------
        plate_maps : pandas.DataFrame
            Table mapping plate names to source/batch/type/path columns.
        wells_metadata : pandas.DataFrame
            Per-well metadata with ``Plate_type``, ``well_position`` and
            ``Plate_label`` columns.
        plate_names : list of str
            Plates this dataset is allowed to sample from.
        well_rows : int
            Number of well rows ('A'..). A 384-well plate has 16 rows, so
            the defaults are now 16 rows x 24 columns (the previous 24x16
            defaults were swapped; all known callers pass 16, 24 explicitly).
        well_cols : int
            Number of well columns (24 on a 384-well plate).
        fields : int
            Number of imaged fields (sites) per well.
        channels : int
            Number of fluorescence channels per field.
        shuffle : bool
            When True, sample plates/wells/fields at random instead of
            iterating them in order.
        """
        # `super(TiffS3Dataset).__init__()` created an unbound super object
        # and never initialized the parent class; use the zero-arg form.
        super().__init__()
        self._plate_maps = plate_maps
        self._wells_metadata = wells_metadata
        self._plate_names = plate_names
        self._well_rows = well_rows
        self._well_cols = well_cols
        self._fields = fields
        self._channels = channels
        self._shuffle = shuffle
        # Overridden by `s3dataset_worker_init_fn` so every DataLoader worker
        # iterates a disjoint, strided subset of the plates.
        self._worker_sel = slice(None)
        self._worker_id = 0
        self._num_workers = 1
        # The S3 handle is created lazily inside `__iter__` because it cannot
        # be pickled when the DataLoader forks worker processes.
        self._s3 = None

    def __iter__(self):
        self._s3 = s3fs.S3FileSystem(anon=True)
        # Work on a local slice: re-assigning `self._plate_names` in place
        # (as before) shrank the dataset every time it was iterated again.
        plate_names = self._plate_names[self._worker_sel]
        well_row_range = range(self._well_rows)
        well_col_range = range(self._well_cols)
        fields_range = range(self._fields)
        for plate_name, well_row, well_col, field_id in product(plate_names, well_row_range, well_col_range, fields_range):
            if self._shuffle:
                # Ignore the deterministic combination and draw a random one.
                plate_name = random.choice(plate_names)
                well_row = random.randrange(self._well_rows)
                well_col = random.randrange(self._well_cols)
                field_id = random.randrange(self._fields)
            curr_plate_map = self._plate_maps.query(f"Plate_name == '{plate_name}'")
            curr_plate_records = curr_plate_map.to_dict(orient='records')
            # Check for missing plates BEFORE indexing the records list
            # (the original indexed `[0]` first, raising IndexError).
            if not curr_plate_records:
                continue
            curr_plate_metadata = curr_plate_records[0]
            curr_image = load_well(curr_plate_metadata, well_row, well_col, field_id, self._channels, self._s3)
            if curr_image is None:
                continue
            # Crop and zero-pad every image to a fixed 1080x1080 so samples
            # from different sources can be batched together.
            curr_image = curr_image[:, :1080, :1080]
            _, h, w = curr_image.shape
            pad_h = 1080 - h
            pad_w = 1080 - w
            if pad_h or pad_w:
                curr_image = np.pad(curr_image, ((0, 0), (0, pad_h), (0, pad_w)))
            if curr_plate_metadata["Plate_type"] == "DMSO":
                # DMSO (solvent-only) plates are all negative controls.
                curr_label = 0
            else:
                # Single quotes inside the f-string: nested double quotes are
                # a SyntaxError before Python 3.12.
                well_position = f"{string.ascii_uppercase[well_row]}{well_col + 1:02d}"
                curr_label = self._wells_metadata.query(
                    f"Plate_type=='{curr_plate_metadata['Plate_type']}' & well_position=='{well_position}'"
                )["Plate_label"]
                if not len(curr_label):
                    continue
                curr_label = curr_label.item()
            yield curr_image, curr_label, curr_plate_metadata
        self._s3 = None
# Create the datasets from the list of URLs
training_ds = TiffS3Dataset(pert_plate_maps, wells_metadata, trn_plates, 16, 24, 9, 5, shuffle=True)
validation_ds = TiffS3Dataset(pert_plate_maps, wells_metadata, val_plates, 16, 24, 9, 5, shuffle=True)
testing_ds = TiffS3Dataset(pert_plate_maps, wells_metadata, tst_plates, 16, 24, 9, 5, shuffle=True)
Import a pre-trained model from torchvision
import torch
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
weights = MobileNet_V3_Small_Weights.DEFAULT
model = mobilenet_v3_small(weights=weights)
Change the last layers of the pre-trained model to convert it into a feature extraction function
org_avgpool = model.avgpool
model.avgpool = torch.nn.Identity()
model.classifier = torch.nn.Identity()
model.cuda()
model.eval()
MobileNetV3(
(features): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(1): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
(1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): SqueezeExcitation(
(avgpool): AdaptiveAvgPool2d(output_size=1)
(fc1): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
(fc2): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
(activation): ReLU()
(scale_activation): Hardsigmoid()
)
(2): Conv2dNormActivation(
(0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(2): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(16, 72, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(72, 72, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=72, bias=False)
(1): BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(2): Conv2dNormActivation(
(0): Conv2d(72, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(3): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(24, 88, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(88, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(88, 88, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=88, bias=False)
(1): BatchNorm2d(88, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(2): Conv2dNormActivation(
(0): Conv2d(88, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(24, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(4): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(1): Conv2dNormActivation(
(0): Conv2d(96, 96, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=96, bias=False)
(1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(2): SqueezeExcitation(
(avgpool): AdaptiveAvgPool2d(output_size=1)
(fc1): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1))
(fc2): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1))
(activation): ReLU()
(scale_activation): Hardsigmoid()
)
(3): Conv2dNormActivation(
(0): Conv2d(96, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(40, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(5): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(1): Conv2dNormActivation(
(0): Conv2d(240, 240, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=240, bias=False)
(1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(2): SqueezeExcitation(
(avgpool): AdaptiveAvgPool2d(output_size=1)
(fc1): Conv2d(240, 64, kernel_size=(1, 1), stride=(1, 1))
(fc2): Conv2d(64, 240, kernel_size=(1, 1), stride=(1, 1))
(activation): ReLU()
(scale_activation): Hardsigmoid()
)
(3): Conv2dNormActivation(
(0): Conv2d(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(40, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(6): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(1): Conv2dNormActivation(
(0): Conv2d(240, 240, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=240, bias=False)
(1): BatchNorm2d(240, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(2): SqueezeExcitation(
(avgpool): AdaptiveAvgPool2d(output_size=1)
(fc1): Conv2d(240, 64, kernel_size=(1, 1), stride=(1, 1))
(fc2): Conv2d(64, 240, kernel_size=(1, 1), stride=(1, 1))
(activation): ReLU()
(scale_activation): Hardsigmoid()
)
(3): Conv2dNormActivation(
(0): Conv2d(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(40, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(7): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(40, 120, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(120, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(1): Conv2dNormActivation(
(0): Conv2d(120, 120, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=120, bias=False)
(1): BatchNorm2d(120, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(2): SqueezeExcitation(
(avgpool): AdaptiveAvgPool2d(output_size=1)
(fc1): Conv2d(120, 32, kernel_size=(1, 1), stride=(1, 1))
(fc2): Conv2d(32, 120, kernel_size=(1, 1), stride=(1, 1))
(activation): ReLU()
(scale_activation): Hardsigmoid()
)
(3): Conv2dNormActivation(
(0): Conv2d(120, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(8): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(48, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(144, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(1): Conv2dNormActivation(
(0): Conv2d(144, 144, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=144, bias=False)
(1): BatchNorm2d(144, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(2): SqueezeExcitation(
(avgpool): AdaptiveAvgPool2d(output_size=1)
(fc1): Conv2d(144, 40, kernel_size=(1, 1), stride=(1, 1))
(fc2): Conv2d(40, 144, kernel_size=(1, 1), stride=(1, 1))
(activation): ReLU()
(scale_activation): Hardsigmoid()
)
(3): Conv2dNormActivation(
(0): Conv2d(144, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(9): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(48, 288, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(288, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(1): Conv2dNormActivation(
(0): Conv2d(288, 288, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=288, bias=False)
(1): BatchNorm2d(288, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(2): SqueezeExcitation(
(avgpool): AdaptiveAvgPool2d(output_size=1)
(fc1): Conv2d(288, 72, kernel_size=(1, 1), stride=(1, 1))
(fc2): Conv2d(72, 288, kernel_size=(1, 1), stride=(1, 1))
(activation): ReLU()
(scale_activation): Hardsigmoid()
)
(3): Conv2dNormActivation(
(0): Conv2d(288, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(10): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(1): Conv2dNormActivation(
(0): Conv2d(576, 576, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=576, bias=False)
(1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(2): SqueezeExcitation(
(avgpool): AdaptiveAvgPool2d(output_size=1)
(fc1): Conv2d(576, 144, kernel_size=(1, 1), stride=(1, 1))
(fc2): Conv2d(144, 576, kernel_size=(1, 1), stride=(1, 1))
(activation): ReLU()
(scale_activation): Hardsigmoid()
)
(3): Conv2dNormActivation(
(0): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(11): InvertedResidual(
(block): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(1): Conv2dNormActivation(
(0): Conv2d(576, 576, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=576, bias=False)
(1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
(2): SqueezeExcitation(
(avgpool): AdaptiveAvgPool2d(output_size=1)
(fc1): Conv2d(576, 144, kernel_size=(1, 1), stride=(1, 1))
(fc2): Conv2d(144, 576, kernel_size=(1, 1), stride=(1, 1))
(activation): ReLU()
(scale_activation): Hardsigmoid()
)
(3): Conv2dNormActivation(
(0): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(96, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(12): Conv2dNormActivation(
(0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(576, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(2): Hardswish()
)
)
(avgpool): Identity()
(classifier): Identity()
)
model_transforms = weights.transforms()
model_transforms
ImageClassification(
crop_size=[224]
resize_size=[256]
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
interpolation=InterpolationMode.BILINEAR
)
Create a torch DataLoader to train PyTorch models
from tqdm.auto import tqdm
from torch.utils.data.dataloader import DataLoader

# Extract frozen-backbone features for the training set, checkpointing to
# disk every 100 batches so a long run can be inspected or resumed.
batch_size = 100
training_dl = DataLoader(
    training_ds,
    batch_size=batch_size,
    num_workers=8,
    worker_init_fn=s3dataset_worker_init_fn,
)

features = []
targets = []
for i, (x, y, _) in tqdm(enumerate(training_dl)):
    if i >= 1000:  # cap the run at 1000 batches (10 checkpoints of 100 each)
        break
    b, c, h, w = x.shape
    # Feed every channel as a separate grayscale image tiled to 3 channels so
    # the ImageNet-pretrained backbone can consume it.
    x_t = model_transforms(torch.tile(x.reshape(-1, 1, h, w), (1, 3, 1, 1)))
    if torch.cuda.is_available():
        x_t = x_t.cuda()
    with torch.no_grad():
        x_out = model(x_t)
    # Move to CPU, regroup per-image channels, and sum their (576, 7, 7)
    # feature maps; then global-average-pool down to a 576-dim vector.
    x_out = x_out.detach().cpu().reshape(-1, c, 576, 7, 7).sum(dim=1)
    x_out = org_avgpool(x_out).detach().reshape(b, -1)
    features.append(x_out)
    targets.append(y)
    if (i + 1) % 100 == 0:
        features = torch.cat(features, dim=0)
        # The labels are mapped as NONE/DMSO = 0, ORF = 1, CRISPR = 2, and COMPOUND = 3
        targets = torch.cat(targets, dim=0)
        ckpt_name = f"trn_features_{i // 100:03d}.pt"  # build once; reused below
        torch.save(dict(features=features, targets=targets), ckpt_name)
        print("Saved features checkpoint", ckpt_name, features.shape, targets.shape)
        features = []
        targets = []
Saved features checkpoint trn_features_001.pt torch.Size([10000, 576]) torch.Size([10000])
Saved features checkpoint trn_features_002.pt torch.Size([10000, 576]) torch.Size([10000])
Saved features checkpoint trn_features_003.pt torch.Size([10000, 576]) torch.Size([10000])
Saved features checkpoint trn_features_004.pt torch.Size([10000, 576]) torch.Size([10000])
Saved features checkpoint trn_features_005.pt torch.Size([10000, 576]) torch.Size([10000])
Saved features checkpoint trn_features_006.pt torch.Size([10000, 576]) torch.Size([10000])
Saved features checkpoint trn_features_007.pt torch.Size([10000, 576]) torch.Size([10000])
Saved features checkpoint trn_features_008.pt torch.Size([10000, 576]) torch.Size([10000])
Saved features checkpoint trn_features_009.pt torch.Size([10000, 576]) torch.Size([10000])
# Extract features for the validation set (50 batches of 100 samples).
val_features = []
val_targets = []
validation_dl = DataLoader(
    validation_ds,
    batch_size=100,
    num_workers=8,
    worker_init_fn=s3dataset_worker_init_fn,
)
for i, (x, y, _) in tqdm(enumerate(validation_dl)):
    if i >= 50:
        break
    b, c, h, w = x.shape
    # Tile each grayscale channel to 3 channels for the pretrained backbone.
    x_t = model_transforms(torch.tile(x.reshape(-1, 1, h, w), (1, 3, 1, 1)))
    if torch.cuda.is_available():
        x_t = x_t.cuda()
    with torch.no_grad():
        x_out = model(x_t)
    # FIX: add .cpu() (present in the training loop but missing here) so the
    # accumulated/saved features are device-independent tensors and GPU
    # memory is not retained across the whole loop.
    x_out = x_out.detach().cpu().reshape(-1, c, 576, 7, 7).sum(dim=1)
    x_out = org_avgpool(x_out).detach().reshape(b, -1)
    val_features.append(x_out)
    val_targets.append(y)
val_features = torch.cat(val_features, dim=0)
# The labels are mapped as NONE/DMSO = 0, ORF = 1, CRISPR = 2, and COMPOUND = 3
val_targets = torch.cat(val_targets, dim=0)
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
^^^^^^^^^^^^
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a child process'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
^^^^^^^^^^^^
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a child process'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
^^^^^^^^^^^^
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a child process'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
^^^^^^^^^^^^
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a child process'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
^^^^^^^^^^^^
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a child process'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
^^^^^^^^^^^^
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a child process'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
^^^^^^^^^^^^
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a child process'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
^^Exception ignored in: ^^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^
^Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^Exception ignored in: ^self._shutdown_workers()^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
if w.is_alive(): File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
assert self._parent_pid == os.getpid(), 'can only test a child process'
if w.is_alive():
^ ^ ^ ^ ^^ ^ ^^^^^^^^^^^^^^^^^^^^^
^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^ ^^assert self._parent_pid == os.getpid(), 'can only test a child process'^^
^^
^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^ ^ assert self._parent_pid == os.getpid(), 'can only test a child process'^ ^
^ ^ ^ ^ ^ ^ ^^ ^ ^^^ ^ ^^ ^^ ^^^^^
^^AssertionError^^^: ^^can only test a child process
^^^^^^^^Exception ignored in: ^^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^^^
^^Traceback (most recent call last):
^^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^ ^^self._shutdown_workers()^^
^^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^^ ^^if w.is_alive():^^
^ ^^^ ^^ ^^^
^ AssertionError^ : ^ can only test a child process^^
^^
^AssertionError^: Exception ignored in: ^can only test a child process<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
^
^Traceback (most recent call last):
^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^ ^Exception ignored in: self._shutdown_workers()^
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
assert self._parent_pid == os.getpid(), 'can only test a child process'if w.is_alive():
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
^ ^ ^ ^^^ ^ ^^ ^^^^^^^^^^^^^^^^
^^^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^ ^^assert self._parent_pid == os.getpid(), 'can only test a child process'^^
^ ^^ ^^
^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^ ^ ^ assert self._parent_pid == os.getpid(), 'can only test a child process' ^
^ ^^^ ^^ ^^ ^ ^ ^^ ^^^^ ^^^ ^
^ AssertionError^^: ^^^can only test a child process^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^AssertionError^: ^can only test a child process^
^^^Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
AssertionErrorTraceback (most recent call last):
: File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
can only test a child process
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>if w.is_alive():
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():^
^ ^ ^ ^ ^ ^^ ^^^^^^^^^
^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^ assert self._parent_pid == os.getpid(), 'can only test a child process'^^
^^^^^^^^
^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^assert self._parent_pid == os.getpid(), 'can only test a child process'
^ ^ ^ ^ ^ ^ ^^ ^^^ ^ ^^^^^^^^^^^^^^^^^
AssertionError: ^can only test a child process^
^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
^^^^^^^^^^^^
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a child process'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
^^^^^^^^^^^^
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a child process'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
^^^^^^^^^^^^
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a child process'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
^^^^^^^^^^^^
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a child process'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
^^^^^^^^^^^^
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a child process'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
^^^^^^^^^^^^
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a child process'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
^^^^^^^^^^^^
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a child process'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
^^^^^^^^^^^^
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a child process'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
^^^^^^^^^^^^
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a child process'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_G23_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_L14_T0001F007L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_C06_T0001F007L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_N23_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_P12_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_I11_T0001F007L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_J03_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_F11_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_M11_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_J16_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_C18_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_P06_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_D18_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_K21_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-162003/Dest210628-162003_L12_T0001F009L01A01Z01C01.tif
torch.save(dict(features=val_features, targets=val_targets), "val_features.pt")

# Extract features for the test set (50 batches), mirroring the validation loop.
tst_features = []
tst_targets = []
testing_dl = DataLoader(
    testing_ds,
    batch_size=batch_size,
    num_workers=8,
    worker_init_fn=s3dataset_worker_init_fn,
)
for i, (x, y, _) in tqdm(enumerate(testing_dl)):
    if i >= 50:
        break
    b, c, h, w = x.shape
    # Tile each grayscale channel to 3 channels for the pretrained backbone.
    x_t = model_transforms(torch.tile(x.reshape(-1, 1, h, w), (1, 3, 1, 1)))
    if torch.cuda.is_available():
        x_t = x_t.cuda()
    with torch.no_grad():
        x_out = model(x_t)
    # FIX: add .cpu() (present in the training loop but missing here) so the
    # accumulated features live on the CPU and do not pin GPU memory.
    x_out = x_out.detach().cpu().reshape(-1, c, 576, 7, 7).sum(dim=1)
    x_out = org_avgpool(x_out).detach().reshape(b, -1)
    tst_features.append(x_out)
    tst_targets.append(y)
tst_features = torch.cat(tst_features, dim=0)
# The labels are mapped as NONE/DMSO = 0, ORF = 1, CRISPR = 2, and COMPOUND = 3
tst_targets = torch.cat(tst_targets, dim=0)
Exception ignored in: Exception ignored in: Exception ignored in:
Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0><function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>Exception ignored in: Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>self._shutdown_workers()self._shutdown_workers()Exception ignored in: Exception ignored in: Traceback (most recent call last):
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0> File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
self._shutdown_workers()if w.is_alive():self._shutdown_workers()Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
if w.is_alive():self._shutdown_workers()
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
AssertionError File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
self._shutdown_workers()self._shutdown_workers() :
if w.is_alive():
if w.is_alive(): if w.is_alive(): can only test a child process File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():if w.is_alive(): ^
^ ^ ^ ^ ^^ ^ ^ ^ ^^^^ ^^ ^^ ^^^^ ^ ^^^^^^^^ ^^^^^^^^^
^^^^^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^^^^^ ^^^^^
assert self._parent_pid == os.getpid(), 'can only test a child process'^^^^^^
^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^^^^^^
^^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^
^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^assert self._parent_pid == os.getpid(), 'can only test a child process' File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^
assert self._parent_pid == os.getpid(), 'can only test a child process'^
^assert self._parent_pid == os.getpid(), 'can only test a child process' File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^ ^assert self._parent_pid == os.getpid(), 'can only test a child process' ^
assert self._parent_pid == os.getpid(), 'can only test a child process'
^^ ^ ^ ^ ^ ^ ^ ^ ^ ^^^^^^
^^^ ^ ^ ^ ^ ^ ^^^^^^ ^ ^^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^ ^^^ ^^^^^^assert self._parent_pid == os.getpid(), 'can only test a child process'^^^ ^^^^^^^^
^^^^^^^^ ^^^^ ^^^^^ ^^^^^^^^^^^^^^^ ^^^^^ ^^^^^ ^^^^^ ^^^^^^ ^^^^^^^^^^^^^^^ ^^^ ^ ^^^^^^^^^^^^^^^^^^^^
^^^AssertionError^^^^^^: ^^^can only test a child process^^^
^^AssertionError^^^^: ^^^^can only test a child process^
^
^^
^^^^^AssertionError^^^:
^^^^AssertionErrorcan only test a child process^: ^^can only test a child process
^^
^^^^^^^
AssertionError:
can only test a child process
^AssertionError: ^can only test a child process^
^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
self._shutdown_workers()
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
Traceback (most recent call last):
Exception ignored in: File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0> ^
Traceback (most recent call last):
self._shutdown_workers() File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^
^^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>self._shutdown_workers()if w.is_alive():
^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^ ^ ^ ^
if w.is_alive(): Traceback (most recent call last):
^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Exception ignored in:
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0> File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
self._shutdown_workers()assert self._parent_pid == os.getpid(), 'can only test a child process'
^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^ Traceback (most recent call last):
^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^if w.is_alive():^
^self._shutdown_workers() ^Exception ignored in:
^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0> ^^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^
^ Traceback (most recent call last):
^ if w.is_alive(): ^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^ ^ ^
^ self._shutdown_workers()^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^ ^
^ ^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^ ^assert self._parent_pid == os.getpid(), 'can only test a child process' ^ ^^^
^if w.is_alive(): ^^^ ^
^
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^ ^ ^^^^assert self._parent_pid == os.getpid(), 'can only test a child process' ^ ^^
^^^^ ^^
^ ^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^ ^^ assert self._parent_pid == os.getpid(), 'can only test a child process'^ ^ ^ ^^
^^^ ^^^ ^^^ ^ ^^ ^^ ^^
^^^ Exception ignored in: ^^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^ ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0> ^^ ^^assert self._parent_pid == os.getpid(), 'can only test a child process'
^
^^ ^Traceback (most recent call last):
^^^ ^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^^ ^^ ^ ^^^self._shutdown_workers()
^^ ^AssertionError^
^: ^^ ^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^ can only test a child process
^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^ ^ ^^ ^assert self._parent_pid == os.getpid(), 'can only test a child process'if w.is_alive():^^
^
^^^ ^^^^ ^^^^ ^ ^^ ^^^Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^^^ ^
^ ^^Traceback (most recent call last):
^^ ^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^ ^ ^ ^ ^^self._shutdown_workers() ^^
^^^ ^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^^^^^^^^^^ ^^^^^^^^^if w.is_alive():^^^^^
^^ ^^^^^ ^^^^^ ^^^^^ ^^^^^^ ^^^^^
^ ^^^^ AssertionError
^^^AssertionError^
^^: : ^^^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^can only test a child process^can only test a child process^
^^^
^^assert self._parent_pid == os.getpid(), 'can only test a child process'
^^Exception ignored in: ^^^^ ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0> ^^ Exception ignored in: ^^ <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^ ^ ^^
^AssertionErrorTraceback (most recent call last):
:
^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^ Traceback (most recent call last):
^ can only test a child process^^ self._shutdown_workers()
assert self._parent_pid == os.getpid(), 'can only test a child process' File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^
^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
AssertionError^self._shutdown_workers()Exception ignored in: ^Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
^if w.is_alive():^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
:
Traceback (most recent call last):
can only test a child process ^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^if w.is_alive():self._shutdown_workers()^ ^
self._shutdown_workers()
Exception ignored in: ^
^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^^ ^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^ ^
^ if w.is_alive():AssertionErrorif w.is_alive(): ^
^ ^:
Traceback (most recent call last):
^^can only test a child process ^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^ ^ ^^
^^ ^self._shutdown_workers()^ ^^
^^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^ ^^^ Exception ignored in: ^^^ ^^^if w.is_alive(): ^ <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^^^^^^
^ ^^ ^^^^
^^^^Traceback (most recent call last):
^^^^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^ ^^^ ^^^^ ^self._shutdown_workers()^^ ^^^^^
^^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^^^^ ^^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^^^if w.is_alive():^^^^
^^ ^
^^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a child process'^ ^^
^^ ^^ assert self._parent_pid == os.getpid(), 'can only test a child process'assert self._parent_pid == os.getpid(), 'can only test a child process'^^^ ^
^
^ ^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^ ^AssertionError^ : ^^ assert self._parent_pid == os.getpid(), 'can only test a child process'
^
can only test a child process File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^ ^ ^ assert self._parent_pid == os.getpid(), 'can only test a child process' ^ ^
Exception ignored in:
^^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0> AssertionError^
^ : ^ ^can only test a child process Traceback (most recent call last):
^ ^^
^^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^ ^^^^ ^^^^ ^^ ^^^Exception ignored in: ^ ^
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>self._shutdown_workers()^^^^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^
^
^ Traceback (most recent call last):
^^^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^assert self._parent_pid == os.getpid(), 'can only test a child process'^^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^
^ ^^ ^^^if w.is_alive(): ^self._shutdown_workers()
^ ^
^^^ ^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^^ ^^ if w.is_alive():^ ^^ ^^^
^ ^ ^ ^ ^^^ ^^^^ ^^^ ^^^ ^^^^ ^^ ^^^ ^^^^ ^^^ ^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^AssertionError^^^^^
: ^^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^^^can only test a child process^^ ^^
^^^AssertionError^^assert self._parent_pid == os.getpid(), 'can only test a child process'^^^:
^^can only test a child process^^
^AssertionError
^^: Exception ignored in: AssertionError File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>can only test a child process^ can only test a child process
assert self._parent_pid == os.getpid(), 'can only test a child process'
Exception ignored in: ^
^ <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0> ^
^ Exception ignored in: Traceback (most recent call last):
^Traceback (most recent call last):
^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^ <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>Exception ignored in: File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
AssertionError^
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0> : Traceback (most recent call last):
self._shutdown_workers()
^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^can only test a child process
self._shutdown_workers() ^Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^ if w.is_alive(): ^^if w.is_alive():^
^self._shutdown_workers()^^
Exception ignored in: ^^ <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>
self._shutdown_workers() ^
^^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
Traceback (most recent call last):
^ ^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^AssertionError ^if w.is_alive():^:
self._shutdown_workers()
^^can only test a child process File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^
^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^if w.is_alive(): ^^^
^ ^Exception ignored in: ^ ^^^if w.is_alive(): ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^
^^^
Traceback (most recent call last):
^^ ^ ^^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^^^ ^^^^^ ^^ ^ ^^^ ^^ ^^^^^
^^^^^^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^self._shutdown_workers()^ ^^^^
^assert self._parent_pid == os.getpid(), 'can only test a child process'^^^
^^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^^if w.is_alive():^^^
^^
^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^ ^^ ^^^^^^
^ ^assert self._parent_pid == os.getpid(), 'can only test a child process' ^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^
^^ ^^^^ ^^ ^assert self._parent_pid == os.getpid(), 'can only test a child process'^ ^
^^^ ^^ ^
^ ^
^^AssertionError
^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^ ^
File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
: ^ AssertionError^^assert self._parent_pid == os.getpid(), 'can only test a child process'can only test a child process : ^ ^
can only test a child process ^
^
^assert self._parent_pid == os.getpid(), 'can only test a child process'^ ^^ ^^
^^ ^ ^^ ^ Exception ignored in: ^^^ <function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^^^
^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^
^ ^ ^Traceback (most recent call last):
^assert self._parent_pid == os.getpid(), 'can only test a child process'^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^ ^
^ ^^ ^^ ^ ^self._shutdown_workers() ^^^ ^
^^^^ ^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^ ^^^^ ^^^if w.is_alive():^^ ^
^^^ ^ ^^^^ ^^ ^^^^ ^ ^^ ^^^ ^^ ^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^AssertionError
^: ^^^AssertionError^^: ^can only test a child process^^^^^^can only test a child process
^^^^
^^^^^^^^^
^^^^^^^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^^^ ^^
^assert self._parent_pid == os.getpid(), 'can only test a child process'^^^
AssertionError^: ^ ^
can only test a child process AssertionError
^ ^: can only test a child process^^
^ ^^^ ^ ^Exception ignored in:
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fffa43334c0>^
AssertionError ^Traceback (most recent call last):
: ^ ^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
can only test a child process^^ ^
self._shutdown_workers()^^^
^^^ File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
if w.is_alive():^AssertionError
: ^can only test a child process
^ ^^^^^^^^^^^^^^^^^^^^^^
^ File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^ ^assert self._parent_pid == os.getpid(), 'can only test a child process'^
^ ^^ ^^ ^ ^ ^ ^ ^
AssertionError^: can only test a child process^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_L23_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_L03_T0001F007L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_O07_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_N12_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_G18_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_I07_T0001F007L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_P09_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_L10_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_C17_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_I05_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_A08_T0001F007L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_K03_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_I02_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_N16_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_E02_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_K12_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_F24_T0001F007L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_J01_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_A17_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_L17_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_I02_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_B16_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_F12_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_O08_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_K02_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_D17_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_M21_T0001F007L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_D20_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_D22_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_M21_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_C20_T0001F009L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_K03_T0001F008L01A01Z01C01.tif
Failed retrieving: s3://cellpainting-gallery/cpg0016-jump/source_10/images/2021_06_28_U2OS_48_hr_run9/images/Dest210628-161651/Dest210628-161651_M24_T0001F007L01A01Z01C01.tif
# Persist the extracted test-set features/targets so they can be reloaded later
# without re-running the (slow) S3 feature-extraction step.
torch.save(dict(features=tst_features, targets=tst_targets), "tst_features.pt")

# --- Set up the model training as an optimization problem ---
# A tiny linear classifier over the 576-d pooled features. The hidden layer is
# deliberately 2-dimensional so the learned representation can be scatter-plotted
# later; the head maps that 2-d bottleneck to the 3 perturbation classes.
classifier = torch.nn.Sequential(
    torch.nn.Linear(in_features=576, out_features=2, bias=False),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=2, out_features=3, bias=False),
)
if torch.cuda.is_available():
    classifier.cuda()

optimizer = torch.optim.SGD(classifier.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

# Mini-batch loaders over the pre-computed (feature, target) pairs.
# Training batches are shuffled each epoch; validation order is fixed.
trn_feat_dl = DataLoader(list(zip(features, targets)), batch_size=100, shuffle=True)
val_feat_dl = DataLoader(list(zip(val_features, val_targets)), batch_size=100, shuffle=False)

# Per-epoch metric histories, filled by the training loop below.
avg_loss_trn = []
avg_acc_trn = []
avg_loss_val = []
avg_acc_val = []

n_epochs = 100
q = tqdm(total=n_epochs)
for e in range(n_epochs):
    # --- Training loop ---
    classifier.train()
    n_dmso = 0
    n_crispr = 0
    n_orf = 0
    loss_epoch = 0.0
    acc_epoch = 0.0
    for x, y in trn_feat_dl:
        optimizer.zero_grad()
        if torch.cuda.is_available():
            x = x.cuda()
            y = y.cuda()
        y_pred = classifier(x.squeeze())
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        loss_epoch += loss.item()
        # Accumulate plain Python floats/ints so the history lists never hold
        # (possibly CUDA) tensors — the original appended tensors, which breaks
        # plt.plot on GPU runs.
        acc_epoch += (torch.sum(y_pred.argmax(dim=1) == y) / len(y)).item()
        n_dmso += int(torch.sum(y == 0))
        n_crispr += int(torch.sum(y == 1))
        n_orf += int(torch.sum(y == 2))
    avg_loss_trn.append(loss_epoch / len(trn_feat_dl))
    avg_acc_trn.append(acc_epoch / len(trn_feat_dl))
    n_total = n_dmso + n_crispr + n_orf
    trn_class_props = [n_dmso / n_total, n_crispr / n_total, n_orf / n_total]

    # --- Validation loop (no gradients, model in eval mode) ---
    classifier.eval()
    n_dmso = 0
    n_crispr = 0
    n_orf = 0
    loss_epoch = 0.0
    acc_epoch = 0.0
    for x_val, y_val in val_feat_dl:
        with torch.no_grad():
            if torch.cuda.is_available():
                # BUG FIX: the original moved the stale *training* batch
                # (x, y) to the GPU here, leaving the validation batch on the
                # CPU (which would fail against a CUDA model) and clobbering
                # the training variables.
                x_val = x_val.cuda()
                y_val = y_val.cuda()
            y_val_pred = classifier(x_val.squeeze())
            loss = loss_fn(y_val_pred, y_val)
            loss_epoch += loss.item()
            acc_epoch += (torch.sum(y_val_pred.argmax(dim=1) == y_val) / len(y_val)).item()
            # BUG FIX: the original counted class frequencies from the
            # training batch `y` instead of the validation batch `y_val`.
            n_dmso += int(torch.sum(y_val == 0))
            n_crispr += int(torch.sum(y_val == 1))
            n_orf += int(torch.sum(y_val == 2))
    avg_loss_val.append(loss_epoch / len(val_feat_dl))
    avg_acc_val.append(acc_epoch / len(val_feat_dl))
    n_total = n_dmso + n_crispr + n_orf
    val_class_props = [n_dmso / n_total, n_crispr / n_total, n_orf / n_total]
    # BUG FIX: validation loss used format spec ":04f" (min-width 4) instead
    # of ":0.4f" (4 decimal places) like the other metrics.
    q.set_description(f"Average training loss: {avg_loss_trn[-1]:0.4f} (Accuracy: {100 * avg_acc_trn[-1]:0.4f} %). Average validation loss: {avg_loss_val[-1]:0.4f} (Accuracy: {100 * avg_acc_val[-1]:0.4f} %)")
    q.update()

# Class proportions seen in the last epoch (notebook display expression).
trn_class_props, val_class_props

import matplotlib.pyplot as plt

plt.plot(avg_loss_trn, "k-", label="Training loss")
plt.plot(avg_loss_val, "b:", label="Validation loss")
plt.legend()

plt.plot(avg_acc_trn, "k-", label="Training accuracy")
plt.plot(avg_acc_val, "b:", label="Validation accuracy")
plt.legend()

# Re-create the loaders with a large batch so a single `next(...)` grabs a big
# sample of each split for visualization.
trn_feat_dl = DataLoader(list(zip(features, targets)), batch_size=2000, shuffle=True)
val_feat_dl = DataLoader(list(zip(val_features, val_targets)), batch_size=2000, shuffle=True)
x_trn, y_trn = next(iter(trn_feat_dl))
x_val, y_val = next(iter(val_feat_dl))

# Project the features through only the first (576 -> 2) layer of the
# classifier to obtain 2-d coordinates for plotting.
classifier.eval()
with torch.no_grad():
    fx_trn = classifier[0](x_trn)
    fx_val = classifier[0](x_val)

fx_trn.shape  # notebook display expression

# BUG FIX: "COMPUND" was misspelled; it should read "COMPOUND".
# NOTE(review): the classifier head only outputs 3 logits, so the 4th label is
# reachable only through `class_names[y_comp]` for held-out compound wells.
class_names = ["NONE/DMSO", "CRISPR", "ORF", "COMPOUND"]
markers = ['o', 's', '^', 'v']
# Scatter the 2-d bottleneck coordinates of the training sample, one marker
# style per perturbation class.
for y_idx, class_name in enumerate(class_names):
    cls_mask = y_trn == y_idx
    plt.scatter(x=fx_trn[cls_mask, 0], y=fx_trn[cls_mask, 1], marker=markers[y_idx], label=class_name)
plt.legend()
plt.show()

# Same projection for the validation sample.
for y_idx, class_name in enumerate(class_names):
    cls_mask = y_val == y_idx
    plt.scatter(x=fx_val[cls_mask, 0], y=fx_val[cls_mask, 1], marker=markers[y_idx], label=class_name)
plt.legend()
plt.show()

# --- Check what compounds are similar according to their phenotypic profile ---
# Anonymous (public-read) client for the Cell Painting Gallery bucket.
s3 = s3fs.S3FileSystem(anon=True)

comp_plate_maps          # notebook display expression
compound_wells_metadata  # notebook display expression

# Keep the first plate-map row as a one-row DataFrame (double brackets).
comp_plate_map = comp_plate_maps.iloc[[0]]
comp_plate_map

# --- Load the image from the AWS bucket ---
x_comp, y_comp = load_well(comp_plate_map, wells_metadata, 1, 1, 0, 5, s3)

# --- Add a dummy axis to treat a single sample as a batch of size one ---
# assumes load_well returns a (channels, height, width) numpy array — TODO confirm
x_comp = torch.from_numpy(x_comp[None, ...])

x_comp.shape, x_comp.dtype, y_comp  # notebook display expression

# --- Extract features with the baseline model ---
# Each channel is fed separately, replicated to 3 RGB planes for the backbone;
# the per-channel 576x7x7 feature maps are then summed and average-pooled.
b, c, h, w = x_comp.shape
x_comp_t = model_transforms(torch.tile(x_comp.reshape(-1, 1, h, w), (1, 3, 1, 1)))
with torch.no_grad():
    x_out = model(x_comp_t)
    x_out = x_out.detach().reshape(-1, c, 576, 7, 7).sum(dim=1)
    x_out = org_avgpool(x_out).detach().reshape(b, -1)

# --- Predict the type of perturbation with the classifier model ---
classifier.eval()
with torch.no_grad():
    y_pred_comp = classifier(x_out)
    fx_comp = classifier[0](x_out)

y_pred_comp.argmax(), class_names[y_pred_comp.argmax().item()], class_names[y_comp]

markers = ['o', 's', '^', 'v']
# Overlay the held-out compound well ("x" marker) on the training-set
# bottleneck scatter to see which class cluster it lands near.
for y_idx, class_name in enumerate(class_names):
    cls_mask = y_trn == y_idx
    plt.scatter(x=fx_trn[cls_mask, 0], y=fx_trn[cls_mask, 1], marker=markers[y_idx], label=class_name)
plt.scatter(x=fx_comp[0, 0], y=fx_comp[0, 1], marker="x", label="Test")
plt.legend()
plt.show()